import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

import argparse
import time
import warnings

import sys
sys.path.append('../code')
from classification import *

from pathlib import Path
import pickle

def parse_commandline():
    """Read this script's options from the command line.

    Returns
    -------
    argparse.Namespace
        Namespace whose ``params`` attribute is the list of strings given
        after ``--params`` (one or more values), or ``None`` when the flag
        is not supplied.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--params", nargs='+', default=None)
    return cli.parse_args()

###############################################################################
# BEGIN MAIN FUNCTION
###############################################################################
if __name__ == '__main__':
    # Suppress every warning for the remainder of the run.
    warnings.simplefilter('ignore')

    # The first value after --params is the integer seed. It seeds NumPy's
    # global RNG and is forwarded to classification_onefold below.
    args = parse_commandline()
    if not args.params:
        # Without this guard a missing --params surfaces as an opaque
        # "'NoneType' object is not subscriptable" TypeError.
        sys.exit('error: --params SEED is required '
                 '(the first value is used as the integer seed)')
    seed = int(args.params[0])
    np.random.seed(seed)

    # set the number of bootstrap samples
    n_bs = 10000

    # whether to balance the train set size between the groups
    ## 1.0 means perfectly balanced
    ## None means no balancing
    # ratio = 1.0
    ratio = None
    test_size = 0.33
    delta = 0.01

    # Create the output directory up front so that a missing or unwritable
    # location fails before any model fitting or bootstrapping, not after.
    out_dir = Path('/home/kop5674/FA_frontier/result/diabetes/')
    out_dir.mkdir(parents=True, exist_ok=True)

    # load the data
    df = pd.read_csv('../data/diabetes.csv')
    # Every column whose name contains 'admit' is treated as an outcome column
    # and excluded from the features; the group column is excluded as well.
    Y_columns = [y for y in df.columns if 'admit' in y]
    Y_column = 'readmit_binary'
    G_column = 'female'
    X_columns = df.columns.drop(Y_columns)
    X_columns = X_columns.drop(G_column)

    # train models (two separate logistic-regression estimators)
    clf_b = LogisticRegression(max_iter=1000)
    clf_r = LogisticRegression(max_iter=1000)

    res, X_test_blue, y_test_blue, X_test_red, y_test_red =\
        classification_onefold(df, Y_column, G_column, X_columns,
                               clf_b, clf_r, seed=seed, ratio=ratio,
                               verbose=True, test_size=test_size,
                               standardize=False)

    # Pull the estimators and the four error values out of the result dict.
    # NOTE(review): presumably e_<model>_<GROUP> is the error of model b/r on
    # the blue (B) / red (R) test set - confirm against classification.py.
    clf_b = res['clf_b']
    clf_r = res['clf_r']
    e_b_B = res['res_blue']['e_b']
    e_r_B = res['res_blue']['e_r']
    e_b_R = res['res_red']['e_b']
    e_r_R = res['res_red']['e_r']

    # compute the p-values
    # `delta` is passed through from the configuration above instead of being
    # repeated as a literal, so editing the config line actually takes effect.
    # NOTE(review): the bootstrap seed is fixed at 42 rather than tied to
    # `seed`; kept unchanged - confirm this is intentional.
    p_b, p_r, e_b_B_list, e_r_B_list, e_b_R_list, e_r_R_list =\
        bs_p_value_45(clf_b, clf_r, X_test_blue, y_test_blue, X_test_red,
                      y_test_red, e_b_B, e_r_B, e_b_R, e_r_R,
                      delta=delta, n_bs=n_bs, seed=42)

    # save the results: one .npy file per bootstrap list
    bootstrap_lists = {
        'e_b_B': e_b_B_list,
        'e_r_B': e_r_B_list,
        'e_b_R': e_b_R_list,
        'e_r_R': e_r_R_list,
    }
    for label, values in bootstrap_lists.items():
        np.save(out_dir / f'test_45_logreg_{label}_list_{seed}_{n_bs}.npy',
                values)

    # ... and one pickle holding the p-values and the four error values
    filename = f'test_45_logreg_{seed}_{ratio}.pkl'
    summary = {'p_b': p_b, 'p_r': p_r, 'e_b_B': e_b_B, 'e_r_B': e_r_B,
               'e_b_R': e_b_R, 'e_r_R': e_r_R}

    print(summary)

    with open(out_dir / filename, 'wb') as f:
        pickle.dump(summary, f)